
import torch
import torch.nn as nn
import os
import subprocess
import logging
import time
import random
import torch
from typing import Tuple
import torch.nn.functional as F
from transformers import AutoTokenizer
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import scipy
from datetime import timedelta
from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix


def fisher_matrix_diag(model, dataloader):
    '''
        Compute the diagonal of the empirical Fisher matrix for EWC.

        Only parameters whose name contains 'encoder' are considered. For each
        batch, the squared gradient of the batch loss is weighted by the number
        of label entries in the batch (len(y.flatten())). The accumulated sum is
        divided by the total number of label entries seen.

        Params:
            - model: a model exposing `encoder`, `forward(X)` and
              `batch_loss(logits, y)` (the first returned element is the loss)
            - dataloader: an iterable yielding (idx, X, y) batches
        Return:
            - fisher: a dict mapping parameter name -> tensor of the parameter's
              shape holding the Fisher diagonal (no autograd history).
              If the dataloader is empty, all entries are zero.
    '''
    model.train()
    # Remember where the encoder lived so the model can be moved back afterwards.
    ori_device = None
    if hasattr(model.encoder, 'device'):
        ori_device = model.encoder.device
    model.cuda()

    # zeros_like (rather than 0*p) so NaN/inf in the weights cannot leak in.
    fisher = {n: torch.zeros_like(p.data)
              for n, p in model.named_parameters() if 'encoder' in n}

    total_cnt = 0
    for idx, X, y in dataloader:
        batch_size = len(y.flatten())
        total_cnt += batch_size
        X, y = X.cuda(), y.cuda()
        # Forward and backward for this batch only.
        model.zero_grad()
        logits = model.forward(X)
        loss, _, _ = model.batch_loss(logits, y)
        loss.backward()
        # Accumulate size-weighted squared gradients.
        for n, p in model.named_parameters():
            if n in fisher and p.grad is not None:
                fisher[n] += batch_size * p.grad.detach().pow(2)

    # Mean over all label entries; skip on an empty dataloader to avoid 0/0 -> NaN.
    if total_cnt > 0:
        for n in fisher:
            fisher[n] = fisher[n] / total_cnt

    if ori_device is not None:
        model.to(ori_device)

    return fisher

def get_center(X, Y, num_class=None):
    '''
        Compute the class centers (mean feature per class) of X.
        This function computes all classes at once;
        see also compute_class_feature_center.

        Params:
            - X : a tensor has dims (num_samples, hidden_dims)
            - Y : a tensor has dims (num_samples), and the label indexes are from 0 ~ num_class-1
            - num_class: an integer indicating the label range (0 ~ num_class-1).
              NOTE(review): when None it defaults to X.shape[1], which only equals
              the label range if the second dim of X is the number of classes
              (e.g. logits). Pass it explicitly for hidden features.
        Return:
            - class_center: a tensor has dims (num_seen_class, hidden_dims),
              on X's device and with X's dtype
            - class_seen_mask: a list of bools with length num_class;
              True for classes that occur in Y
    '''
    # Keep every tensor on X's device (CPU or CUDA) and in X's dtype.
    X_device = X.device
    Y = Y.to(X_device)

    if num_class is None:
        num_class = int(X.shape[1])

    # Classes whose center can be calculated (they occur in Y).
    class_seen_mask = [bool(i in Y) for i in range(num_class)]
    unseen_classes = [i for i, seen in enumerate(class_seen_mask) if not seen]

    # Add one zero-feature dummy sample per unseen class so every class count is
    # >= 1 (no division by zero); those rows are dropped again at the end.
    unseen_class_index = torch.tensor(unseen_classes, dtype=torch.long, device=X_device)
    Y = torch.cat((Y.long(), unseen_class_index))
    unseen_class_X = torch.zeros((len(unseen_classes), X.shape[1]), dtype=X.dtype, device=X_device)
    X = torch.cat((X, unseen_class_X), dim=0)

    # One-hot labels built directly on X's device/dtype: indexing a CPU eye with a
    # CUDA index tensor fails, and a float32 one-hot cannot be matmul'd with float64 X.
    Y_onehot = torch.eye(num_class, dtype=X.dtype, device=X_device)[Y]

    # Per-class sum of features divided by per-class count.
    class_count = torch.sum(Y_onehot, dim=0)
    class_center = torch.matmul(Y_onehot.T, X) / class_count.reshape(-1, 1)
    class_center = class_center[class_seen_mask, :]

    return class_center, class_seen_mask

def compute_class_feature_center(dataloader, model, select_class_indexes, is_normalize=True):
    '''
        Compute the mean feature (center) of each selected class.

        Params:
            - dataloader: torch.utils.data.DataLoader
            - model: a model
            - select_class_indexes: a list of selected classes indexes (e.g. [1,2] or [0,1,2])
            - is_normalize: if normalize the features
        Return:
            - class_center_matrix: a tensor has dims (len(select_class_indexes), hidden_dims);
              row i is the center of select_class_indexes[i]. A class with no
              samples yields a row of NaN (mean of an empty set).
    '''
    features_matrix, y_list = compute_feature_by_dataloader(dataloader,
                                                            model,
                                                            is_normalize=is_normalize)
    # compute_feature_by_dataloader returns numpy arrays; torch.eq / torch.mean
    # require tensors, so convert here (as_tensor is a no-op for tensors).
    features = torch.as_tensor(features_matrix)
    labels = torch.as_tensor(y_list)

    class_center_list = []
    for class_idx in select_class_indexes:
        class_mask = torch.eq(labels, class_idx)
        class_center_list.append(torch.mean(features[class_mask], dim=0, keepdim=True))

    return torch.cat(class_center_list, dim=0)

def compute_feature_by_dataloader(dataloader, model, is_normalize=False, return_idx=False):
    '''
        Compute encoder features for every batch of a dataloader yielding (idx, X, Y).

        The model is moved to CUDA and switched to eval mode for the forward passes.
        Afterwards (even if a forward pass raises) its previous train/eval mode is
        restored. If it exposes a `device` attribute, it is also moved back to that device.

        Params:
            - dataloader: torch.utils.data.DataLoader yielding (idx, inputs, targets)
            - model: a torch.nn.Module with a `forward_encoder(inputs)` method
            - is_normalize: if L2-normalize each feature row
            - return_idx: if also return the concatenated batch idx
        Return:
            - idx_list (only when return_idx=True): np.ndarray of concatenated idx
            - features_matrix: np.ndarray with dims (num_samples, hidden_dims)
            - y_list: np.ndarray with dims (num_samples,)
    '''
    device = model.device if hasattr(model, 'device') else None
    # Remember the caller's mode so it can be restored instead of forcing train().
    was_training = model.training
    model.cuda()
    model.eval()

    idx_list = []
    features_list = []
    y_list = []
    try:
        with torch.no_grad():
            for idx, inputs, targets in dataloader:
                inputs_feature = model.forward_encoder(inputs.cuda()).cpu()
                if return_idx:
                    if isinstance(idx, tuple):
                        idx = np.array(idx)
                    elif isinstance(idx, torch.Tensor):
                        idx = idx.cpu().numpy()
                    idx_list.append(idx)
                # Collapse all leading dims so feature rows line up with the
                # flattened targets (presumably one row per token).
                features_list.append(inputs_feature.reshape(-1, inputs_feature.shape[-1]).numpy())
                y_list.append(targets.flatten().cpu().numpy())
    finally:
        model.train(was_training)
        if device is not None:
            model.to(device)

    features_matrix = np.concatenate(features_list, axis=0)
    y_list = np.concatenate(y_list)

    if is_normalize:
        # NOTE(review): all-zero feature rows become NaN here (0/0).
        features_matrix = features_matrix / np.linalg.norm(features_matrix, ord=2, axis=-1, keepdims=True)

    if return_idx:
        return np.concatenate(idx_list), features_matrix, y_list

    return features_matrix, y_list

def get_match_id(feature_matrix: torch.Tensor, top_k: int=5, 
                max_samples: int =10000, metric: str='euclidean', 
                largest: bool=False) -> Tuple[torch.Tensor, torch.Tensor]:
    """
        For every sample, find the ids of its top_k nearest (or farthest) samples.

        A sample is never matched with itself. When there are more than
        `max_samples` samples, only the first `max_samples` rows are used as
        candidate neighbours. The distance matrix is then
        (num_samples, max_samples) instead of (num_samples, num_samples).

        Params:
            - feature_matrix (num_samples, hidden_dims): the matrix of last hidden states
              (array-likes are converted with torch.as_tensor)
            - top_k: for each sample, return the id of the top_k nearest samples
            - max_samples: number of maximum candidate samples, to avoid "out of memory"
            - metric: 'euclidean' or 'cosine' (cosine distance = 1 - cosine similarity)
            - largest: if selecting the samples with the largest distances (default=False)
        Return:
            - knn_dist_matrix (num_samples, top_k): the distance of the knn
            - knn_id_matrix (num_samples, top_k): the row index of the knn in feature_matrix
        Raises:
            - ValueError: if metric is neither 'euclidean' nor 'cosine'
    """
    feature_matrix = torch.as_tensor(feature_matrix)
    num_samples_all = feature_matrix.shape[0]
    use_subset = num_samples_all > max_samples
    num_candidates = max_samples if use_subset else num_samples_all

    if metric == 'euclidean':
        if use_subset:
            # Import the submodule explicitly: a bare `import scipy` does not
            # load scipy.spatial on older SciPy releases.
            from scipy.spatial.distance import cdist
            feats = feature_matrix.detach().cpu().numpy()
            dist_z = torch.tensor(cdist(feats, feats[:max_samples], 'euclidean'),
                                  device=feature_matrix.device)
        else:
            dist_z = pdist(feature_matrix, squared=False)
    elif metric == 'cosine':
        normed = feature_matrix / torch.norm(feature_matrix, dim=-1).reshape(-1, 1)
        candidates = normed[:max_samples, :] if use_subset else normed
        dist_z = 1 - torch.matmul(normed, candidates.T)
    else:
        raise ValueError('Invalid metric %s' % (metric))

    # Exclude each sample from its own neighbour list. The sentinel must lose
    # the top-k competition in the requested direction: +inf when picking the
    # smallest distances, -inf when picking the largest (a +inf sentinel
    # would otherwise always be selected with largest=True).
    self_mask = torch.eye(num_samples_all, num_candidates, device=dist_z.device).bool()
    fill_value = float('-inf') if largest else float('inf')
    dist_z = dist_z.masked_fill(self_mask, fill_value)

    knn_dist_matrix, knn_id_matrix = torch.topk(dist_z, top_k, largest=largest, dim=1)
    return knn_dist_matrix, knn_id_matrix

def pdist(e, squared: bool=False, eps: float=1e-12) -> torch.Tensor:
    """
        Pairwise L2 distances between the rows of a feature matrix.

        Uses the expansion ||a-b||^2 = ||a||^2 + ||b||^2 - 2<a,b>. The result is
        clamped from below by `eps` (guarding against small negative values from
        cancellation) before the optional square root. The diagonal is forced to 0.

        Params:
            - e: a feature matrix has dims (num_samples, hidden_dims)
            - squared: if True, return squared distances
            - eps: lower clamp applied before the square root
        Return:
            - a distance matrix has dims (num_samples, num_samples)
    """
    sq_norms = (e ** 2).sum(dim=1)
    gram = e @ e.T
    dist = (sq_norms[:, None] + sq_norms[None, :] - 2 * gram).clamp(min=eps)

    if not squared:
        dist = dist.sqrt()

    # Work on a copy before the in-place diagonal write.
    dist = dist.clone()
    dist.fill_diagonal_(0)
    return dist

def plot_embedding(X, Y):
    '''
        Scatter-plot the first two columns of X, coloured by label.

        Params:
            - X: a feature matrix whose first two columns are used as (x, y) coordinates
            - Y: a label list has dims (num_samples), used as point colours
    '''
    first_dim = X[:, 0]
    second_dim = X[:, 1]
    plt.scatter(first_dim, second_dim, c=Y, marker='.', cmap=plt.cm.Spectral)

def plot_centers(X, label_list):
    '''
        Plot class centers as stars and annotate each with its class name.

        Params:
            - X: a matrix has dims (num_classes, >=2); the first two columns are
              used as (x, y) coordinates
            - label_list: a list has dims (num_classes); entry i names the center in row i
    '''
    num_centers = X.shape[0]
    center_colors = list(range(1, num_centers + 1))
    plt.scatter(X[:, 0], X[:, 1], c=center_colors, marker='*')

    for row, class_name in enumerate(label_list):
        plt.text(X[row, 0], X[row, 1], s=str(class_name), size=15)

def _to_cpu_tensor(value):
    '''
        Return a detached CPU tensor copy of a list, tensor or array-like (e.g. numpy array).
    '''
    if isinstance(value, list):
        return torch.tensor(value)
    if isinstance(value, torch.Tensor):
        return value.clone().detach().cpu()
    return torch.as_tensor(value).clone()


def _min_max_scale(low_repre):
    '''
        Rescale every column of a 2-D tensor to the range [0, 1].
    '''
    x_min, x_max = torch.min(low_repre, 0)[0], torch.max(low_repre, 0)[0]
    return (low_repre - x_min) / (x_max - x_min)


def plot_distribution(X, Y, label_list, class_center_matrix=None, sample_ratio=1.0, select_labels=None):
    '''
        Visualize the feature X in the 2-D space with t-SNE

        Params:
            - X: a feature matrix has dims (num_samples, hidden_dims); a list,
              tensor (any device) or numpy array
            - Y: a label list has dims (num_samples); a list, tensor or numpy array
            - label_list: a list has dims (num_classes) 
            and it represents the name of each class
            - class_center_matrix: if not None, plot the class center of each class;
            it has dims (num_classes, hidden_dims)
            - sample_ratio: the ratio of the samples used for visualization, in (0, 1]
            - select_labels: a list represents the selected labels for visualization
    '''
    # Work on detached CPU copies so the caller's data is never modified.
    _X = _to_cpu_tensor(X)
    _Y = _to_cpu_tensor(Y)
    num_samples = _Y.shape[0]
    print('Total %d samples for visualization'%num_samples)

    # Random sub-sampling (uses the `random` module's global RNG).
    if sample_ratio<1.0:
        assert sample_ratio>0.0, "Invalid sample ratio!!!"
        sample_lst = list(range(num_samples))
        random.shuffle(sample_lst)
        sample_lst = sample_lst[:int(num_samples*sample_ratio)]
        _X = _X[sample_lst]
        _Y = _Y[sample_lst]
        print('Select %d samples for visualization'%_Y.shape[0])

    # Keep only the requested labels (`is not None`: `!= None` is ambiguous for arrays).
    if select_labels is not None and len(select_labels) > 0:
        class_mask = torch.zeros(_Y.shape, dtype=torch.bool)
        for label in select_labels:
            class_mask |= torch.eq(_Y, label)
        _Y = _Y[class_mask]
        _X = _X[class_mask]

    tsne = TSNE(n_components=2)
    if class_center_matrix is not None:
        assert len(label_list)==class_center_matrix.shape[0], "Number of classes is not consistent!!!"
        num_class = class_center_matrix.shape[0]
        # Centers may live on GPU / require grad; bring them to CPU in X's dtype
        # so they can be concatenated and fed to t-SNE together with the samples.
        centers = _to_cpu_tensor(class_center_matrix).to(_X.dtype)
        concat_X = torch.cat((_X, centers), dim=0)
        concat_low_repre = _min_max_scale(torch.tensor(tsne.fit_transform(concat_X)))

        plot_embedding(concat_low_repre[:-num_class, :], _Y)
        plot_centers(concat_low_repre[-num_class:, :], label_list)
    else:
        low_repre = _min_max_scale(torch.tensor(tsne.fit_transform(_X)))
        plot_embedding(low_repre, _Y)
    plt.show()
    
def save_predicts_to_txt(x_list, y_list, pred_list, label_list, pred_file_name='pred_file.txt', pad_token_label_id=-100, backbone='bert-base-cased'):
    '''
        Write one "<token>\\t<gold label>\\t<predicted label>" line per non-padding
        position to a UTF-8 text file (overwriting it).

        Params:
            - x_list: a tensor has dims (num_samples,) of token ids
            - y_list: a tensor has dims (num_samples,) of gold label ids
            - pred_list: a tensor has dims (num_samples,) of predicted label ids
            - label_list: a list mapping label id -> label name
            - pred_file_name: a path for the save file
            - pad_token_label_id: positions whose gold label equals this id are skipped
            - backbone: a valid pretrained name for 'transformers' (used to load the tokenizer)
    '''
    assert x_list.size(0) == y_list.size(0)
    assert x_list.size(0) == pred_list.size(0)

    tokenizer = AutoTokenizer.from_pretrained(backbone)

    with open(pred_file_name, "w", encoding='utf-8') as out_file:
        for token_id, gold_id, pred_id in zip(x_list, y_list, pred_list):
            gold_id = int(gold_id)
            if gold_id == pad_token_label_id:
                continue
            pred_name = label_list[pred_id]
            gold_name = label_list[gold_id]
            token_text = str(tokenizer.decode(token_id))
            out_file.write("\t".join([str(token_text), str(gold_name), str(pred_name)]) + "\n")

def plot_confusion_matrix(pred_list, y_list, label_list, pad_token_label_id=-100):
    '''
        Plot the confusion matrix of model predictions as a heatmap, excluding the 'O' class.

        Params:
            - pred_list: predicted label ids (list or tensor on any device) has dims (num_samples,)
            - y_list: gold label ids (list or tensor on any device) has dims (num_samples,)
            - label_list: a list of label names; label id i is named label_list[i]
              and it must contain 'O'
            - pad_token_label_id: positions whose gold label equals this id are ignored
        Raises:
            - ValueError: if 'O' is not in label_list
    '''
    # Filter out padding positions; work on flat CPU tensors so .numpy() is valid.
    preds = torch.as_tensor(pred_list).detach().cpu().flatten()
    golds = torch.as_tensor(y_list).detach().cpu().flatten()
    pad_mask = torch.not_equal(golds, pad_token_label_id)
    preds, golds = preds[pad_mask].numpy(), golds[pad_mask].numpy()

    O_index = label_list.index('O')
    # Pin the label set so row/column i always corresponds to label_list[i],
    # even when some classes never appear among golds or predictions.
    cm = confusion_matrix(golds, preds, labels=list(range(len(label_list))))
    cm_without_o = np.delete(np.delete(cm, O_index, axis=0), O_index, axis=1)
    names_without_o = label_list[:O_index] + label_list[O_index+1:]
    df = pd.DataFrame(cm_without_o,
                    columns=names_without_o,
                    index=names_without_o)

    cmap = sns.color_palette("mako", as_cmap=True)
    sns.heatmap(df, cmap=cmap, xticklabels=True, yticklabels=True,annot=True)
    plt.xticks(rotation=-45)
    plt.xlabel('Predict label')
    plt.ylabel('Actual label')
    plt.show()

def plot_prob_hist_each_class(y_list, logits_list, ignore_label_lst=(-100, 0), pad_token_label_id=-100):
    '''
        For each gold class, plot a histogram of the softmax probability the model
        assigns to that class on the samples predicted as it, split into correct
        (gold == class) and wrong (gold != class) predictions.

        Params:
            - y_list: a tensor has dims (num_samples,) of gold label ids
            - logits_list: a tensor has dims (num_samples, num_classes)
            - ignore_label_lst: class ids for which no histogram is drawn
            - pad_token_label_id: samples whose gold label equals this id are dropped first
    '''
    pad_mask = torch.not_equal(y_list, pad_token_label_id)
    y_list, logits_list = y_list[pad_mask], logits_list[pad_mask]

    pred_list = torch.argmax(logits_list, dim=-1)
    prob_list = torch.softmax(logits_list, dim=-1)

    # Histogram edges: coarse below 0.9, then increasingly fine towards 1.
    bins = list(np.arange(0,0.9,0.1)) + [0.9, 0.99, 0.999, 0.9999, 0.99999, 1]

    # torch.unique is sorted, so the plots come out in ascending class order.
    for label_id in torch.unique(y_list).tolist():
        if label_id in ignore_label_lst:
            continue
        pred_mask = torch.eq(pred_list, label_id)
        gold_mask = torch.eq(y_list, label_id)
        # detach().cpu() so tensors on GPU or with autograd history convert to numpy.
        y_logits_correct = prob_list[pred_mask & gold_mask][:, label_id].detach().cpu().numpy()
        y_logits_wrong = prob_list[pred_mask & ~gold_mask][:, label_id].detach().cpu().numpy()
        print(len(y_logits_correct))
        print(len(y_logits_wrong))
        plt.hist([y_logits_correct,y_logits_wrong], 
                    bins=bins,
                    color=['green','red'], 
                    alpha=0.75)
        plt.legend(['Correct','Wrong'])
        plt.title('Prob distribution for class idx %d'%label_id)
        plt.show()

def decode_sentence(sentence, auto_tokenizer):
    '''
        Decode a sentence from token ids to a string.

        Each id is decoded on its own. BERT special tokens are dropped, and
        WordPiece continuation pieces (prefixed with '##') are glued to the
        previous piece. Every other piece is preceded by one space, so a
        non-empty result starts with a space.

        Params:
            - sentence: a list of ids (encoded by the tokenizer)
            - auto_tokenizer: a tokenizer for the transformers
        Return:
            - sent_str: sentence string
    '''
    # '[MASK]' (not 'MASK') is BERT's mask token.
    special_tokens = ('[PAD]', '[CLS]', '[SEP]', '[UNK]', '[MASK]')
    pieces = []
    for word_id in sentence:
        word = str(auto_tokenizer.decode(word_id))
        if word in special_tokens:
            continue
        if word.startswith('##'):
            pieces.append(word[2:])
        else:
            pieces.append(' ' + word)
    return ''.join(pieces)

def decode_word_from_sentence(sentence, pos_idx, auto_tokenizer):
    '''
        Decode the word starting at position `pos_idx` of a sentence.

        The token at `pos_idx` is decoded, and every directly following WordPiece
        continuation token ('##'-prefixed) is appended without its prefix.
        Reconstruction stops at the first token that is not a continuation piece,
        or at the end of the sentence. Special tokens such as '[SEP]', '[PAD]'
        or '[MASK]' never start with '##', so they always stop it.

        Params:
            - sentence: a list of ids (encoded by the tokenizer)
            - pos_idx: the position index of the word
            - auto_tokenizer: a tokenizer for the transformers
        Returns:
            - word_str: a string of the selected word
    '''
    word_str = auto_tokenizer.decode(sentence[pos_idx])
    for next_pos in range(pos_idx + 1, len(sentence)):
        next_word = auto_tokenizer.decode(sentence[next_pos])
        if not next_word.startswith('##'):
            break
        word_str = word_str + next_word[2:]
    return word_str

def init_experiment(params, logger_filename):
    '''
        Set up a new experiment run: create its timestamped output directory
        (which also updates params.dump_path), create the logger, and log every
        setting in sorted order.

        Params:
            - params: an object whose attributes (read via vars()) hold all
              hyper-parameters and experimental settings; it needs at least
              `dump_path` and `wandb_name`
            - logger_filename: the log file name, created inside the run directory
        Return:
            - logger: the configured logger
    '''
    get_saved_path(params)

    log_file = os.path.join(params.dump_path, logger_filename)
    logger = create_logger(log_file)
    logger.info('============ Initialized logger ============')

    settings = sorted(dict(vars(params)).items())
    setting_lines = ['%s: %s' % (key, str(value)) for key, value in settings]
    logger.info('\n'.join(setting_lines))
    logger.info('The experiment will be stored in %s\n' % params.dump_path)

    return logger


class LogFormatter(logging.Formatter):
    '''
        A formatter that prefixes each record with its level, the record's
        date/time, and the time elapsed since `start_time` (set at construction;
        it can be reset by assigning `start_time`).

        Continuation lines of multi-line messages are indented to line up with the
        first line. Exception tracebacks and stack info attached to the record
        (e.g. by logger.exception) are appended instead of being dropped.
    '''
    def __init__(self):
        super().__init__()
        self.start_time = time.time()

    def format(self, record):
        elapsed_seconds = round(record.created - self.start_time)

        prefix = "%s - %s - %s" % (
            record.levelname,
            time.strftime('%x %X', time.localtime(record.created)),
            timedelta(seconds=elapsed_seconds)
        )
        message = record.getMessage()

        # Mirror logging.Formatter: cache the formatted traceback on the record.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        extras = []
        if record.exc_text:
            extras.append(record.exc_text)
        if record.stack_info:
            extras.append(self.formatStack(record.stack_info))
        if extras:
            message = '\n'.join(([message] if message else []) + extras)

        # Indent continuation lines past the "<prefix> - " header.
        message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
        return "%s - %s" % (prefix, message) if message else ''


def create_logger(filepath):
    '''
        Create (or re-create) the root logger for the experiment.

        Every handler already attached to the root logger is removed, and file
        handlers among them are closed so their files are released. Then a
        console handler at INFO level is attached, plus, when `filepath` is
        given, an appending file handler at DEBUG level. Both use LogFormatter.

        Params:
            - filepath: the path of the log file, or None for console-only logging
        Return:
            - logger: the root logger, with an added `reset_time()` function that
              resets the elapsed time shown in the log prefix
    '''
    log_formatter = LogFormatter()

    new_handlers = []

    # File handler: everything from DEBUG up.
    if filepath is not None:
        file_handler = logging.FileHandler(filepath, "a")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(log_formatter)
        new_handlers.append(file_handler)

    # Console handler: INFO and above.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(log_formatter)
    new_handlers.append(console_handler)

    logger = logging.getLogger()
    # Detach previous handlers; close old file handlers so repeated calls do not
    # leak open file descriptors (plain `logger.handlers = []` never closes them).
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        if isinstance(old_handler, logging.FileHandler):
            old_handler.close()
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    for handler in new_handlers:
        logger.addHandler(handler)

    # Allow callers to restart the elapsed-time counter of the shared formatter.
    def reset_time():
        log_formatter.start_time = time.time()
    logger.reset_time = reset_time

    return logger


def get_saved_path(params):
    '''
        Create a fresh, timestamped directory for this experiment run and store
        its path in `params.dump_path`.

        The resulting layout is `<dump_path or ./>/<wandb_name>/<YYYY-mm-dd-HH-MM-SS>`.
        If a directory for the current second already exists, the function waits
        one second and tries the next timestamp.

        Params:
            - params: an object (e.g. an argparse Namespace) with `dump_path` and
              `wandb_name` attributes; `dump_path` is overwritten with the run directory
        Raises:
            - OSError: if a directory cannot be created
    '''
    dump_path = "./" if params.dump_path == "" else params.dump_path

    # os.makedirs instead of `mkdir -p` through a shell: no shell injection, and
    # paths containing spaces or shell metacharacters are handled correctly.
    exp_path = os.path.join(dump_path, params.wandb_name)
    os.makedirs(exp_path, exist_ok=True)

    # Atomically claim a new run directory; on a name clash wait for the next second.
    while True:
        run_id = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime(time.time()))
        run_path = os.path.join(exp_path, run_id)
        try:
            os.makedirs(run_path)
        except FileExistsError:
            time.sleep(1)
            continue
        break

    params.dump_path = run_path
